/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph_optimizer/graph_fusion/graph_replace.h"

#include <climits>
#include <sstream>
#include <stack>

#include "external/graph/operator_factory.h"
#include "graph/utils/op_desc_utils.h"
#include "graph_optimizer/fusion_common/graph_node_map_util.h"
#include "register/graph_optimizer/fusion_common/graph_pass_util.h"

using std::map;
using std::stack;
using std::string;
using std::vector;

namespace fe {
// Name of the boolean attribute that UpdateAttr sets to true on fusion nodes; per the comment at its use
// site it serves as a sign of inferring at Graph Engine.
const string NEED_INFER = "isNeedInfer";
// Stores the given ops kernel info store pointer in ops_kernel_info_store_ptr_; does nothing else.
GraphReplace::GraphReplace(shared_ptr<FEOpsKernelInfoStore> ops_kernel_info_store_ptr)
    : ops_kernel_info_store_ptr_(ops_kernel_info_store_ptr) {}
// Empty destructor body: nothing is released explicitly here.
GraphReplace::~GraphReplace() {}

/**
 * Replaces every matched sub graph in `match_results` with the fusion sub graph described by
 * `fusion_rule_pattern`.
 *
 * For each match, in order: refresh the outer anchors, create the fusion nodes, copy the rule attributes and
 * the special attributes onto them, check that the fusion nodes are supported, record them, rewire the graph
 * (Replace) and run the post-fusion bookkeeping. The first failing step aborts the whole call; matches that
 * were already replaced stay replaced.
 *
 * @param match_results       matches to replace; each one is processed on a local copy, so nothing is written
 *                            back into this vector.
 * @param fusion_rule_pattern rule that supplies the fusion nodes and the rule name used in the logs.
 * @param graph               graph that is modified in place.
 * @return SUCCESS when all matches were replaced; otherwise the status of the failing step
 *         (GRAPH_REPLACE_DELETE_NODE_FAILED / GRAPH_REPLACE_CHECKSUPPORTED_FAILED when the support check
 *         fails, FAILED when Replace fails).
 */
Status GraphReplace::ReplaceGraph(vector<GraphMatchResult> &match_results, const FusionRulePattern &fusion_rule_pattern,
                                  ge::ComputeGraph &graph) {
  size_t result_num = match_results.size();
  int32_t effect_times = 0;
  FusionInfo fusion_info(graph.GetSessionID(), to_string(graph.GetGraphID()), fusion_rule_pattern.GetRuleName(), 0, 0);
  // UpdateMatchedOuterAnchor takes a non-const string reference, so one mutable copy of the rule name is
  // kept for the whole loop instead of being rebuilt for every match.
  string rule_name = fusion_rule_pattern.GetRuleName();
  for (size_t i = 0; i < result_num; ++i) {
    const size_t match_no = i + 1;  // 1-based index used in the log messages
    GraphMatchResult match_result = match_results[i];
    // The visible implementation of UpdateMatchedOuterAnchor always returns SUCCESS.
    (void)UpdateMatchedOuterAnchor(match_result, rule_name);
    map<FusionRuleNodePtr, ge::NodePtr> fusion_graph = {};
    Status ret = CreateFusionNodes(fusion_rule_pattern, match_result.origin_nodes, fusion_graph, graph);
    if (ret != SUCCESS) {
      REPORT_FE_ERROR(
          "[GraphOpt][RunFusionRule][RplGph] fusion rule name[%s] No.%zu sub_graph, fusion sub_graph create failed.",
          rule_name.c_str(), match_no);
      return ret;
    }

    ret = UpdateAttr(match_result.origin_nodes, fusion_graph);
    if (ret != SUCCESS) {
      REPORT_FE_ERROR(
          "[GraphOpt][RunFusionRule][RplGph] fusion rule name[%s] No.%zu sub_graph, update attribute value failed.",
          rule_name.c_str(), match_no);
      return ret;
    }

    ret = UpdateSpecialAttr(match_result.origin_nodes, fusion_graph);
    if (ret != SUCCESS) {
      // Adjacent literals instead of a backslash continuation, which would embed the source indentation
      // of the following line in the message.
      REPORT_FE_ERROR(
          "[GraphOpt][RunFusionRule][RplGph] fusion rule name[%s] No.%zu sub_graph, update special "
          "attribute value failed.",
          rule_name.c_str(), match_no);
      return ret;
    }

    if (CheckFusionNode(match_result, fusion_rule_pattern, fusion_graph) == FAILED) {
      FE_LOGW("fusion rule name[%s] No.%zu time fusion Failed, fusion node not support.", rule_name.c_str(),
              match_no);
      // DeleteNodes can also exit through FE_CHECK_NOTNULL, whose status is not visible here, so every
      // non-SUCCESS result is treated as a failed clean-up.
      if (DeleteNodes(fusion_graph, fusion_rule_pattern.GetFusionRuleNodes(), graph) != SUCCESS) {
        REPORT_FE_ERROR(
            "[GraphOpt][RunFusionRule][RplGph] fusion rule name[%s] No.[%zu] sub_graph, delete fusion node failed.",
            rule_name.c_str(), match_no);
        return GRAPH_REPLACE_DELETE_NODE_FAILED;
      }
      return GRAPH_REPLACE_CHECKSUPPORTED_FAILED;
    }

    // Record fusion nodes
    RecordFusionNodes(fusion_graph, match_result);

    if (Replace(match_result, fusion_graph, fusion_rule_pattern, graph) != SUCCESS) {
      REPORT_FE_ERROR(
          "[GraphOpt][RunFusionRule][RplGph] fusion rule name[%s] No.[%zu] sub_graph, edges or nodes replace failed.",
          rule_name.c_str(), match_no);
      return FAILED;
    }

    // Post fusion process, eg. record original name, output anchor map
    PostFusion(match_result);

    // match_no is a size_t, so it has to be printed with %zu.
    FE_LOGD("fusion rule name[%s] No. [%zu] time fusion Success", rule_name.c_str(), match_no);
    effect_times++;
  }
  // get effect times
  fusion_info.SetEffectTimes(effect_times);
  FusionStatisticRecorder::Instance().UpdateGraphFusionEffectTimes(fusion_info);
  // NOTE(review): %d is used for the session id and graph id; confirm that it matches the return types of
  // GetSessionID() / GetGraphID(), which are not visible in this file.
  FE_LOGD("SessionId %d GraphId %d fusion rule name:%s fusion Success, %d times take effect", graph.GetSessionID(),
          graph.GetGraphID(), rule_name.c_str(), effect_times);
  return SUCCESS;
}

void GraphReplace::UpdateOuterInputs(const string &pattern_name, GraphMatchResult &match_result,
                                     std::map<FusionRuleAnchorPtr, ge::AnchorPtr> &outer_inputs) {
  for (auto &origin_anchor_map_pair : match_result.origin_outer_inputs) {
    FusionRuleAnchorPtr rule_anchor_ptr = origin_anchor_map_pair.first;
    for (auto &peer_out_anchor : origin_anchor_map_pair.second->GetPeerAnchors()) {
      ge::NodePtr node = peer_out_anchor->GetOwnerNode();
      if (match_result.origin_nodes_set.find(node) == match_result.origin_nodes_set.end()) {
        outer_inputs.emplace(rule_anchor_ptr, peer_out_anchor);
        FE_LOGD("outerInputs rule_anchor:%s:%d, new_anchor:%s:%d, pattern_name:%s",
                rule_anchor_ptr->GetOwnerNode()->GetNodeName().c_str(), rule_anchor_ptr->GetAnchorIdx(),
                peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(), pattern_name.c_str());

        auto iter = match_result.outer_inputs_set.find(rule_anchor_ptr);
        if (iter == match_result.outer_inputs_set.end()) {
          std::set<ge::AnchorPtr> graph_achors = {peer_out_anchor};
          match_result.outer_inputs_set.emplace(rule_anchor_ptr, graph_achors);
          FE_LOGD("new rule_anchor:%s:%d, new_anchor:%s:%d, pattern_name:%s",
                  rule_anchor_ptr->GetOwnerNode()->GetNodeName().c_str(), rule_anchor_ptr->GetAnchorIdx(),
                  peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(), pattern_name.c_str());
        } else {
          iter->second.insert(peer_out_anchor);
          FE_LOGD("has rule_anchor:%s:%d, new_anchor:%s:%d, pattern_name:%s",
                  rule_anchor_ptr->GetOwnerNode()->GetNodeName().c_str(), rule_anchor_ptr->GetAnchorIdx(),
                  peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(), pattern_name.c_str());
        }
      }
    }
  }
}

// Recomputes match_result.outer_inputs / outer_outputs from the origin_outer_* anchors of the match.
// Only peers whose owner node is outside the matched node set are kept. When either origin map is empty
// the match is left untouched. Always returns SUCCESS.
Status GraphReplace::UpdateMatchedOuterAnchor(GraphMatchResult &match_result, string &pattern_name) {
  const bool nothing_to_update =
      match_result.origin_outer_inputs.empty() || match_result.origin_outer_outputs.empty();
  if (nothing_to_update) {
    FE_LOGW("Not get origin outer input and output, pattern_name[%s]", pattern_name.c_str());
    return SUCCESS;
  }

  // Outer inputs are collected by the dedicated helper.
  std::map<FusionRuleAnchorPtr, ge::AnchorPtr> refreshed_inputs;
  UpdateOuterInputs(pattern_name, match_result, refreshed_inputs);

  // Outer outputs: for every rule anchor gather the peers that live outside the matched sub graph.
  std::map<FusionRuleAnchorPtr, std::set<ge::AnchorPtr>> refreshed_outputs;
  for (auto &origin_entry : match_result.origin_outer_outputs) {
    std::set<ge::AnchorPtr> outside_peers;
    std::set<ge::AnchorPtr> origin_anchors = origin_entry.second;
    for (auto &origin_anchor : origin_anchors) {
      for (auto &peer_in : origin_anchor->GetPeerAnchors()) {
        ge::NodePtr peer_owner = peer_in->GetOwnerNode();
        if (match_result.origin_nodes_set.find(peer_owner) != match_result.origin_nodes_set.end()) {
          continue;
        }
        outside_peers.insert(peer_in);
        FE_LOGD("output rule_anchor:%s:%d, new_anchor:%s:%d, pattern_name:%s",
                origin_entry.first->GetOwnerNode()->GetNodeName().c_str(), origin_entry.first->GetAnchorIdx(),
                peer_in->GetOwnerNode()->GetName().c_str(), peer_in->GetIdx(), pattern_name.c_str());
      }
    }
    refreshed_outputs.emplace(origin_entry.first, outside_peers);
  }

  // A change in the number of entries is only logged; it does not affect the result.
  const bool same_entry_counts = (refreshed_outputs.size() == match_result.outer_outputs.size()) &&
                                 (refreshed_inputs.size() == match_result.outer_inputs.size());
  if (!same_entry_counts) {
    FE_LOGD("patternName:%s, two rules are continuous", pattern_name.c_str());
  }

  match_result.outer_inputs = refreshed_inputs;
  match_result.outer_outputs = refreshed_outputs;
  return SUCCESS;
}

/**
 * Creates one graph node for every fusion rule node of `fusion_rule_pattern` and stores it in
 * `fusion_graph`.
 *
 * A fusion rule node whose name also appears in the matched sub graph (see FindSameNode) is created from a
 * copy of that pre-fusion node's op desc, padded with default tensor descs up to the number of data anchors
 * of the rule node. Every other fusion rule node is created from a fresh op desc through CreateNode.
 *
 * @param fusion_rule_pattern rule providing the fusion rule nodes.
 * @param origin_sub_graph    matched rule node -> graph node mapping of the pre-fusion sub graph.
 * @param fusion_graph        output: fusion rule node -> newly created graph node.
 * @param graph               graph the new nodes are added to.
 * @return SUCCESS, or GRAPH_REPLACE_CREATE_FUSION_NODES_FAILED as soon as one node cannot be created.
 *         Nodes added before the failure are not removed here.
 */
Status GraphReplace::CreateFusionNodes(const FusionRulePattern &fusion_rule_pattern,
                                       const map<FusionRuleNodePtr, ge::NodePtr> &origin_sub_graph,
                                       map<FusionRuleNodePtr, ge::NodePtr> &fusion_graph, ge::ComputeGraph &graph) {
  NodeMapInfoPtr node_map_info = nullptr;
  (void)GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph);
  const auto &fusion_rule_node_set = fusion_rule_pattern.GetFusionRuleNodes();
  for (const auto &fusion_rule_node : fusion_rule_node_set) {
    // using node name of fusion rule to find whether this fusion node is in
    // matched subgraph
    ge::NodePtr origin_node = FindSameNode(fusion_rule_node, origin_sub_graph);
    if (origin_node == nullptr) {
      // No pre-fusion node with the same name: create the node from a new op desc.
      string node_name = CreateNodeName(origin_sub_graph, fusion_rule_pattern, fusion_rule_node->GetNodeType());
      ge::NodePtr node = CreateNode(fusion_rule_node, node_name, graph);
      FE_CHECK(node == nullptr,
               REPORT_FE_ERROR("[GraphOpt][RunFusionRule][Replace] create fusion node with new opdesc failed"),
               return GRAPH_REPLACE_CREATE_FUSION_NODES_FAILED);

      fusion_graph[fusion_rule_node] = node;
      continue;
    }

    // The fusion node has the same name as a pre-fusion node: reuse a copy of its op desc.
    ge::OpDescPtr op_desc = ge::AttrUtils::CopyOpDesc(origin_node->GetOpDesc());
    FE_CHECK(op_desc == nullptr,
             REPORT_FE_ERROR("[GraphOpt][RunFusionRule][Replace] copy opdesc of pre-fusion node[%s] failed",
                             origin_node->GetName().c_str()),
             return GRAPH_REPLACE_CREATE_FUSION_NODES_FAILED);

    // Pad the copied op desc so that it has a desc for every data anchor of the rule node; the results are
    // checked the same way CreateNode checks them.
    ge::GeTensorDesc tensor_desc;
    size_t input_opdesc_count = fusion_rule_node->GetInputDataAnchors().size();
    for (size_t i = op_desc->GetInputsSize(); i < input_opdesc_count; ++i) {
      FE_CHECK(op_desc->AddInputDesc(tensor_desc) != ge::SUCCESS,
               REPORT_FE_ERROR("[GraphOpt][RunFusionRule][Replace] Fail to add input desc for node[%s].",
                               op_desc->GetName().c_str()),
               return GRAPH_REPLACE_CREATE_FUSION_NODES_FAILED);
    }
    size_t output_opdesc_count = fusion_rule_node->GetOutputDataAnchors().size();
    for (size_t i = op_desc->GetOutputsSize(); i < output_opdesc_count; ++i) {
      FE_CHECK(op_desc->AddOutputDesc(tensor_desc) != ge::SUCCESS,
               REPORT_FE_ERROR("[GraphOpt][RunFusionRule][Replace] Fail to add output desc for node[%s].",
                               op_desc->GetName().c_str()),
               return GRAPH_REPLACE_CREATE_FUSION_NODES_FAILED);
    }

    ge::NodePtr node = graph.AddNode(op_desc);
    // Adjacent literals instead of a backslash continuation, which would embed the source indentation of
    // the following line in the message.
    FE_CHECK(node == nullptr,
             REPORT_FE_ERROR("[GraphOpt][RunFusionRule][Replace] create fusion node with pre-fusion "
                             "node opdesc failed"),
             return GRAPH_REPLACE_CREATE_FUSION_NODES_FAILED);

    fusion_graph[fusion_rule_node] = node;
    GraphPassUtil::AddNodeFromOpTypeMap(node_map_info, node);
  }
  return SUCCESS;
}

/**
 * Propagates the special attributes (currently only "_stream_label") from the pre-fusion sub graph to the
 * fusion nodes.
 *
 * For each special attribute, the value of the first node of `origin_sub_graph` (in map iteration order)
 * that carries it is set on every node of `fusion_sub_graph`; the remaining origin nodes are not consulted.
 * If no origin node has the attribute, nothing is set.
 *
 * @return SUCCESS, or GRAPH_REPLACE_UPDATE_ATTR_FAILED when reading or writing an attribute fails.
 */
Status GraphReplace::UpdateSpecialAttr(const map<FusionRuleNodePtr, ge::NodePtr> &origin_sub_graph,
                                       const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  vector<string> spec_attr = {"_stream_label"};
  for (auto &attr_name : spec_attr) {
    for (auto &ori_node : origin_sub_graph) {
      ge::OpDescPtr op_desc = ori_node.second->GetOpDesc();
      if (!ge::AttrUtils::HasAttr(op_desc, attr_name)) {
        FE_LOGD("node %s does not have attr %s", op_desc->GetName().c_str(), attr_name.c_str());
        continue;
      }

      ge::GeAttrValue attr_value;
      if (op_desc->GetAttr(attr_name, attr_value) == ge::GRAPH_FAILED) {
        // Argument order follows the format string: attribute name first, node name second.
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][UpdSpclAttr] get attr %s from node %s error", attr_name.c_str(),
                        op_desc->GetName().c_str());
        return GRAPH_REPLACE_UPDATE_ATTR_FAILED;
      }

      for (auto &fusion_item : fusion_sub_graph) {
        ge::OpDescPtr fusion_op_desc = fusion_item.second->GetOpDesc();
        if (fusion_op_desc->SetAttr(attr_name, attr_value) == ge::GRAPH_FAILED) {
          REPORT_FE_ERROR("[GraphOpt][RunFusionRule][UpdSpclAttr] set attr %s to node %s error", attr_name.c_str(),
                          fusion_op_desc->GetName().c_str());
          return GRAPH_REPLACE_UPDATE_ATTR_FAILED;
        }
      }
      // The first origin node carrying the attribute decides its value for all fusion nodes.
      break;
    }
  }
  return SUCCESS;
}

// Applies the attributes declared by the fusion rule to the created fusion nodes.
// An attribute value is either taken from an attribute of a pre-fusion node (rule attribute) or is a fixed
// value stored in the rule. After every attribute that was set, NEED_INFER is set to true on the fusion
// node as well. Returns GRAPH_REPLACE_UPDATE_ATTR_FAILED on the first lookup/get/set failure, otherwise
// SUCCESS.
Status GraphReplace::UpdateAttr(const map<FusionRuleNodePtr, ge::NodePtr> &origin_sub_graph,
                                const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  for (auto &fusion_entry : fusion_sub_graph) {
    const map<string, FusionRuleAttrValuePtr> &rule_attrs = fusion_entry.first->GetAttributes();
    ge::OpDescPtr target_desc = fusion_entry.second->GetOpDesc();

    for (auto &rule_attr : rule_attrs) {
      string target_attr_name = rule_attr.first;
      FusionRuleAttrValuePtr rule_value = rule_attr.second;
      ge::GeAttrValue resolved_value;

      if (!rule_value->IsFusionRuleAttr()) {
        // Fixed value stored directly in the rule.
        resolved_value = rule_value->GetFixAttrValue();
      } else {
        // Value has to be read from an attribute of a node of the matched sub graph.
        FusionRuleAttr source_attr = rule_value->GetRuleNodeAttrValue();
        auto origin_pos = origin_sub_graph.find(rule_value->GetOwnerNode());
        if (origin_pos == origin_sub_graph.end()) {
          REPORT_FE_ERROR("[GraphOpt][RunFusionRule][UpdAttr] The node[%s] does not in origin SubGraph",
                          rule_value->GetOwnerNode()->GetNodeName().c_str());
          return GRAPH_REPLACE_UPDATE_ATTR_FAILED;
        }

        ge::NodePtr source_node = origin_pos->second;
        string source_attr_name = source_attr.attr_name;
        ge::OpDescPtr source_desc = source_node->GetOpDesc();

        if (!ge::AttrUtils::HasAttr(source_desc, source_attr_name)) {
          REPORT_FE_ERROR("[GraphOpt][RunFusionRule][UpdAttr] The node[%s] does not have attr[%s]",
                          source_node->GetName().c_str(), source_attr_name.c_str());
          return GRAPH_REPLACE_UPDATE_ATTR_FAILED;
        }

        if (source_desc->GetAttr(source_attr_name, resolved_value) == ge::GRAPH_FAILED) {
          REPORT_FE_ERROR("[GraphOpt][RunFusionRule][UpdAttr] get attr[%s] from node[%s] error",
                          source_attr_name.c_str(), source_node->GetName().c_str());
          return GRAPH_REPLACE_UPDATE_ATTR_FAILED;
        }
      }

      if (target_desc->SetAttr(target_attr_name, resolved_value) == ge::GRAPH_FAILED) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][UpdAttr] set attr[%s] to node[%s] error",
                        target_attr_name.c_str(), target_desc->GetName().c_str());
        return GRAPH_REPLACE_UPDATE_ATTR_FAILED;
      }

      // this attribute serves as a sign of infering at Graph Engine
      const bool infer_flag_set = ge::AttrUtils::SetBool(target_desc, NEED_INFER, true);
      if (!infer_flag_set) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][UpdAttr] set attr[%s] to node[%s] error", NEED_INFER.c_str(),
                        target_desc->GetName().c_str());
        return GRAPH_REPLACE_UPDATE_ATTR_FAILED;
      }
    }
  }
  return SUCCESS;
}

// Appends every graph node of `fusion_graph` (in map iteration order) to match_result.fusion_nodes.
void GraphReplace::RecordFusionNodes(map<FusionRuleNodePtr, ge::NodePtr> &fusion_graph,
                                     GraphMatchResult &match_result) {
  for (auto &fusion_entry : fusion_graph) {
    match_result.fusion_nodes.push_back(fusion_entry.second);
  }
}

/**
 * Swaps the matched sub graph for the fusion sub graph: first removes the pre-fusion nodes from `graph`,
 * then connects the outputs of every fusion node to the outer output anchors recorded in `match_result`.
 *
 * @param match_result        match whose origin nodes are removed and whose outer_outputs are relinked.
 * @param fusion_sub_graph    fusion rule node -> created graph node.
 * @param fusion_rule_pattern rule providing the set of origin rule nodes to delete.
 * @param graph               graph that is modified in place.
 * @return SUCCESS, GRAPH_REPLACE_DELETE_NODE_FAILED when the pre-fusion nodes cannot be removed, or
 *         GRAPH_REPLACE_REPLACE_OUTPUT_FAILED when an output cannot be relinked.
 */
Status GraphReplace::Replace(GraphMatchResult &match_result,
                             const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph,
                             const FusionRulePattern &fusion_rule_pattern, ge::ComputeGraph &graph) {
  // DeleteNodes can also exit through FE_CHECK_NOTNULL, whose status is not visible here; any non-SUCCESS
  // result means the removal stopped part-way and must not be treated as success.
  if (DeleteNodes(match_result.origin_nodes, fusion_rule_pattern.GetOriginRuleNodes(), graph) != SUCCESS) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][Replace] remove pre-fusion nodes error");
    return GRAPH_REPLACE_DELETE_NODE_FAILED;
  }

  for (const auto &fusion_item : fusion_sub_graph) {
    FusionRuleNodePtr rule_node = fusion_item.first;
    ge::NodePtr fusion_node = fusion_item.second;

    // replace output anchors
    if (ReplaceOutputAnchors(rule_node, fusion_node, match_result.outer_outputs, fusion_sub_graph) == FAILED) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][Replace] replace output anchors error");
      return GRAPH_REPLACE_REPLACE_OUTPUT_FAILED;
    }
  }
  return SUCCESS;
}

// Post-fusion bookkeeping for one match: first sets the data dump attributes (SetDataDumpAttr), then
// records the original op names on the fusion nodes (RecordOriginOpNames).
void GraphReplace::PostFusion(GraphMatchResult &match_result) {
  SetDataDumpAttr(match_result);
  RecordOriginOpNames(match_result);
}

// Records the pre-fusion nodes of the match on every fusion node through
// GraphPassUtil::RecordOriginalNames. When the match produced more than one fusion node, each of them is
// additionally marked with ge::ATTR_NAME_DATA_DUMP_IS_MULTIOP = true.
void GraphReplace::RecordOriginOpNames(GraphMatchResult &match_result) {
  // RecordOriginalNames is called with a vector, so the node set is copied into one.
  std::vector<ge::NodePtr> original_nodes(match_result.origin_nodes_set.begin(),
                                          match_result.origin_nodes_set.end());

  for (auto node : match_result.fusion_nodes) {
    GraphPassUtil::RecordOriginalNames(original_nodes, node);
  }

  if (match_result.fusion_nodes.size() > 1) {
    for (ge::NodePtr &node : match_result.fusion_nodes) {
      // Best effort, as before: a failure to set the flag is deliberately ignored.
      (void)ge::AttrUtils::SetBool(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_IS_MULTIOP, true);
    }
  }
}

// For every outer output of the match, hands the pair (pre-fusion node output, fusion node output) that
// feeds it to GraphPassUtil::SetOutputDescAttr.
//
// The pre-fusion side is derived from the rule: the first peer of the outer-output rule anchor gives the
// rule node and its output index. The fusion side is derived from the graph: the first recorded outer
// anchor is treated as an input data anchor and its peer output anchor gives the fusion node and index.
void GraphReplace::SetDataDumpAttr(GraphMatchResult &match_result) {
  // Get rule output anchor form match_result.outer_output
  for (auto &map_outer_output : match_result.outer_outputs) {
    FusionRuleAnchorPtr rule_anchor = map_outer_output.first;
    const auto &rule_peer_anchors = rule_anchor->GetPeerAnchors();
    if (rule_peer_anchors.empty()) {
      continue;
    }
    // Get output node anchor peer node output anchor idx
    const auto &rule_src_anchor = rule_peer_anchors.at(0);
    int32_t rule_out_anchor_idx = rule_src_anchor->GetAnchorIdx();
    if (rule_out_anchor_idx == -1) {
      continue;
    }
    // Get rule output node's anchor peer node
    FusionRuleNodePtr rule_node = rule_src_anchor->GetOwnerNode();

    // Get matched graph node form match_result.origin_nodes
    auto it = match_result.origin_nodes.find(rule_node);
    if (it == match_result.origin_nodes.end()) {
      // NOTE(review): this stops processing all remaining outputs, unlike the `continue` cases above;
      // kept as is, confirm whether skipping only this output was intended.
      return;
    }
    ge::NodePtr origin_graph_node = it->second;

    const std::set<ge::AnchorPtr> &outer_output_set = map_outer_output.second;
    if (outer_output_set.empty()) {
      // NOTE(review): same early exit as above; kept as is.
      return;
    }

    auto graph_input_anchor = std::static_pointer_cast<ge::InDataAnchor>(*(outer_output_set.begin()));
    if (graph_input_anchor == nullptr) {
      FE_LOGW("outer output anchor of rule anchor[%s] is null, skip setting output desc attr",
              rule_anchor->GetAnchorName().c_str());
      continue;
    }

    // The input anchor may have no peer output anchor; it must not be dereferenced blindly.
    auto fusion_out_anchor = graph_input_anchor->GetPeerOutAnchor();
    if (fusion_out_anchor == nullptr) {
      FE_LOGW("outer output anchor of rule anchor[%s] has no peer out anchor, skip setting output desc attr",
              rule_anchor->GetAnchorName().c_str());
      continue;
    }

    // Get fusion node output anchor idx
    int32_t graph_out_anchor_idx = fusion_out_anchor->GetIdx();

    // Get fusion graph node
    ge::NodePtr fusion_graph_node = fusion_out_anchor->GetOwnerNode();

    // Set output desc
    (void)GraphPassUtil::SetOutputDescAttr(rule_out_anchor_idx, graph_out_anchor_idx, origin_graph_node,
                                           fusion_graph_node);
  }
}

/**
 * Creates a new graph node for `fusion_rule_node` from a freshly built op desc.
 *
 * The operator is created through ge::OperatorFactory using the first node type of the rule node. Its op
 * desc is copied and padded with default tensor descs up to the number of input/output data anchors of the
 * rule node. The node is then added to `graph` and registered in the graph's op-type map.
 *
 * @param fusion_rule_node rule node describing type and anchor counts.
 * @param node_name        name given to the created operator.
 * @param graph            graph the node is added to.
 * @return the new node, or nullptr when any step fails.
 */
ge::NodePtr GraphReplace::CreateNode(const FusionRuleNodePtr fusion_rule_node, const string &node_name,
                                     ge::ComputeGraph &graph) {
  if (fusion_rule_node->GetNodeType().empty()) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][CrtNd] Node type of FusionRuleNode is empty.");
    return nullptr;
  }
  string node_type = fusion_rule_node->GetNodeType()[0];
  size_t in_anchor_num = fusion_rule_node->GetInputDataAnchors().size();
  size_t out_anchor_num = fusion_rule_node->GetOutputDataAnchors().size();
  auto node_op = ge::OperatorFactory::CreateOperator(node_name, node_type);
  if (node_op.IsEmpty()) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][CrtNd] create fusion node[%s] error", node_type.c_str());
    return nullptr;
  }
  auto temp_opdesc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  node_op.BreakConnect();
  ge::OpDescPtr op_desc = ge::AttrUtils::CopyOpDesc(temp_opdesc);
  // The copy is dereferenced right below, so a failed copy has to be reported instead of crashing.
  if (op_desc == nullptr) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][CrtNd] Fail to copy op desc for node[%s].", node_name.c_str());
    return nullptr;
  }

  // Pad the op desc so that it has a desc for every data anchor of the rule node.
  ge::GeTensorDesc tensor_desc;
  for (size_t i = op_desc->GetInputsSize(); i < in_anchor_num; ++i) {
    if (op_desc->AddInputDesc(tensor_desc) != ge::SUCCESS) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][CrtNd] Fail to add input desc for node[%s].", node_name.c_str());
      return nullptr;
    }
  }
  for (size_t i = op_desc->GetOutputsSize(); i < out_anchor_num; ++i) {
    if (op_desc->AddOutputDesc(tensor_desc) != ge::SUCCESS) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][CrtNd] Fail to add output desc for node[%s].", node_name.c_str());
      return nullptr;
    }
  }

  ge::NodePtr node = graph.AddNode(op_desc);
  if (node == nullptr) {
    // Nothing to register in the op-type map; the caller reports the failure.
    return nullptr;
  }
  NodeMapInfoPtr node_map_info = nullptr;
  (void)GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph);
  (void)GraphPassUtil::AddNodeFromOpTypeMap(node_map_info, node);

  return node;
}

/**
 * Builds the name of a new fusion node.
 *
 * The name starts with the op name of the first node of `origin_sub_graph`, split at '/' and re-joined
 * with a '/' after every part. When `types` is not empty it is followed by
 * "<rule name>/<types[0]><counter>", where the counter is a function-local static shared by all calls.
 *
 * @return the generated name; an empty string when `origin_sub_graph` is empty.
 *
 * NOTE(review): the static counter is read and written without any synchronization in this function.
 */
string GraphReplace::CreateNodeName(const map<FusionRuleNodePtr, ge::NodePtr> &origin_sub_graph,
                                    const FusionRulePattern &fusion_rule_pattern, const vector<string> &types) {
  std::ostringstream fusion_node_name;
  if (origin_sub_graph.empty()) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][CrtNdNm] Origin Sub Graph is empty.");
    return fusion_node_name.str();
  }

  static int fusion_node_count = 0;
  ge::NodePtr origin_node = origin_sub_graph.begin()->second;
  string node_name = origin_node->GetOpDesc()->GetName();
  vector<string> name_vec = StringUtils::Split(node_name, '/');
  for (const auto &name_part : name_vec) {
    fusion_node_name << name_part << "/";
  }
  if (!types.empty()) {
    fusion_node_name << fusion_rule_pattern.GetRuleName() << "/" << types[0] << fusion_node_count;
  }
  // The counter is an int: wrap at INT_MAX. The previous "% ULLONG_MAX" never took effect for an int value
  // and the increment itself overflowed (undefined behaviour) at INT_MAX.
  fusion_node_count = (fusion_node_count == INT_MAX) ? 0 : fusion_node_count + 1;
  return fusion_node_name.str();
}
// Looks up the graph node of `origin_sub_graph` whose rule node has the same node name as
// `fusion_rule_node`. Returns the first hit in map iteration order, or nullptr when no name matches.
ge::NodePtr GraphReplace::FindSameNode(const FusionRuleNodePtr fusion_rule_node,
                                       const map<FusionRuleNodePtr, ge::NodePtr> &origin_sub_graph) {
  string wanted_name = fusion_rule_node->GetNodeName();
  ge::NodePtr same_node = nullptr;
  for (auto entry = origin_sub_graph.begin(); entry != origin_sub_graph.end(); ++entry) {
    if (entry->first->GetNodeName() != wanted_name) {
      continue;
    }
    same_node = entry->second;
    break;
  }
  return same_node;
}
// Removes from `graph` the graph node mapped to every rule node in `rule_nodes`.
// Before a node is removed, all of its input data anchors and its input control anchor are unlinked; the
// original comment explains that this prevents edges from being added automatically on removal. Each
// removed node is also dropped from the graph's op-type map.
// Returns FAILED when a rule node has no entry in `nodes` or the graph refuses to remove a node; a null
// anchor exits through FE_CHECK_NOTNULL. Nodes handled before the failure stay removed.
Status GraphReplace::DeleteNodes(const map<FusionRuleNodePtr, ge::NodePtr> &nodes,
                                 const set<FusionRuleNodePtr> &rule_nodes, ge::ComputeGraph &graph) {
  NodeMapInfoPtr node_map_info = nullptr;
  (void)GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph);

  for (const auto &rule_node : rule_nodes) {
    auto mapped = nodes.find(rule_node);
    if (mapped == nodes.end()) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][DelNd] The node[%s] not in origin_sub_graph",
                      rule_node->GetNodeName().c_str());
      return FAILED;
    }
    ge::NodePtr node = mapped->second;

    // Step 1: cut every incoming data edge.
    size_t anchor_idx = 0;
    while (anchor_idx < node->GetAllInDataAnchors().size()) {
      auto in_data_anchor = node->GetInDataAnchor(anchor_idx);
      FE_CHECK_NOTNULL(in_data_anchor);
      in_data_anchor->UnlinkAll();
      ++anchor_idx;
    }

    // Step 2: cut every incoming control edge.
    auto in_control_anchor = node->GetInControlAnchor();
    FE_CHECK_NOTNULL(in_control_anchor);
    in_control_anchor->UnlinkAll();

    // Step 3: remove the node itself. RemoveNode handles the remaining anchors; the inputs were unlinked
    // first so that no edge is added automatically while the node is deleted.
    if (graph.RemoveNode(node) == ge::GRAPH_FAILED) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][DelNd] remove node[%s] error", node->GetName().c_str());
      return FAILED;
    }

    GraphNodeMapUtil::DelNodeFromOpTypeMap(node_map_info, node);
  }
  return SUCCESS;
}

// Links the inputs of `fusion_node`: data anchors first, then the control anchor.
// Returns FAILED as soon as one of the two steps reports FAILED, otherwise SUCCESS.
Status GraphReplace::ReplaceInputAnchors(const FusionRuleNodePtr &rule_node, ge::NodePtr fusion_node,
                                         const map<FusionRuleAnchorPtr, std::set<ge::AnchorPtr>> &outer_inputs,
                                         const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  if (ReplaceInputDataAnchors(rule_node, fusion_node, outer_inputs, fusion_sub_graph) == FAILED) {
    return FAILED;
  }
  if (ReplaceInputCtrlAnchors(rule_node, fusion_node, outer_inputs, fusion_sub_graph) == FAILED) {
    return FAILED;
  }
  return SUCCESS;
}

/**
 * Links the input control anchor of `fusion_node` according to the rule.
 *
 * The function works in one of two modes, selected by `outer_inputs`:
 *  - `outer_inputs` empty: only edges between fusion nodes are added (peer found in `fusion_sub_graph`);
 *  - `outer_inputs` not empty: only edges from outside the fusion sub graph are added, using the graph
 *    anchors recorded for the peer rule anchor.
 * Peers that fit neither mode are skipped.
 *
 * @return SUCCESS (also when the rule node has no input control anchor), or FAILED when an outer anchor is
 *         missing from `outer_inputs` or an edge cannot be added.
 */
Status GraphReplace::ReplaceInputCtrlAnchors(const FusionRuleNodePtr &rule_node, ge::NodePtr fusion_node,
                                             const map<FusionRuleAnchorPtr, std::set<ge::AnchorPtr>> &outer_inputs,
                                             const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  FE_LOGD("fused rule node:%s, graph node:%s", rule_node->GetNodeName().c_str(), fusion_node->GetName().c_str());
  FusionRuleAnchorPtr input_anchor = rule_node->GetInputCtrlAnchor();
  if (input_anchor == nullptr) {
    return SUCCESS;
  }

  std::set<ge::OutControlAnchorPtr> outer_ctrl_edges;
  ge::InControlAnchorPtr dst_anchor = fusion_node->GetInControlAnchor();

  for (auto &peer_anchor : input_anchor->GetPeerAnchors()) {
    FusionRuleNodePtr peer_rule_node = peer_anchor->GetOwnerNode();
    const bool peer_is_fusion_node = fusion_sub_graph.find(peer_rule_node) != fusion_sub_graph.end();
    if (outer_inputs.empty() && peer_is_fusion_node) {
      // if the edge is between fusion node and fusion node, find peer node in
      // fusion sub graph
      ge::NodePtr peer_node = fusion_sub_graph.at(peer_rule_node);
      ge::OutControlAnchorPtr src_anchor = peer_node->GetOutControlAnchor();
      if (ge::GraphUtils::AddEdge(src_anchor, dst_anchor) == ge::GRAPH_FAILED) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplInCtrlAncr] add in ctrl edge from node[%s] to node[%s] failed",
                        peer_node->GetName().c_str(), fusion_node->GetName().c_str());
        return FAILED;
      }
    } else if (!outer_inputs.empty() && !peer_is_fusion_node) {
      // if the edge is between fusion node and outer input node, find peer
      // anchor in pre-fusion sub graph
      auto peer_anchors_pair = outer_inputs.find(peer_anchor);
      if (peer_anchors_pair == outer_inputs.end()) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplInCtrlAncr] outer input anchor[%s] not in match_result",
                        peer_anchor->GetAnchorName().c_str());
        return FAILED;
      }
      // Collected in a set first so that the same source anchor is linked only once.
      for (auto &peer_graph_anchor : peer_anchors_pair->second) {
        outer_ctrl_edges.insert(dynamic_pointer_cast<ge::OutControlAnchor>(peer_graph_anchor));
      }
    }
    // otherwise: has been linked or do not need to add
  }

  for (auto &src_out_anchor : outer_ctrl_edges) {
    if (ge::GraphUtils::AddEdge(src_out_anchor, dst_anchor) == ge::GRAPH_FAILED) {
      // The source node has to be taken from the failing anchor itself; no fusion-side peer node exists
      // on this path, and the anchor can be null when the cast above did not yield a control anchor.
      ge::NodePtr src_node = nullptr;
      if (src_out_anchor != nullptr) {
        src_node = src_out_anchor->GetOwnerNode();
      }
      const string src_name = (src_node != nullptr) ? src_node->GetName() : string("unknown");
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplInCtrlAncr] add outer ctrl edge from node[%s] to node[%s] failed",
                      src_name.c_str(), fusion_node->GetName().c_str());
      return FAILED;
    }
  }

  return SUCCESS;
}

/**
 * Links the input data anchors of `fusion_node` according to the rule and updates its "is input const"
 * flags.
 *
 * The function works in one of two modes, selected by `outer_inputs`:
 *  - `outer_inputs` empty: only edges between fusion nodes are added (peer found in `fusion_sub_graph`);
 *  - `outer_inputs` not empty: only edges from outside the fusion sub graph are added, using the first
 *    graph anchor recorded for the peer rule anchor.
 * Inputs that fit neither mode are skipped. An input whose peer node has type CONSTANT is flagged as const.
 * The flags are written back to the op desc only when all inputs were processed successfully.
 *
 * @return SUCCESS, or FAILED when a rule input does not have exactly one peer, the outer anchor is
 *         unknown, no peer node can be determined, or an edge cannot be added.
 */
Status GraphReplace::ReplaceInputDataAnchors(const FusionRuleNodePtr &rule_node, ge::NodePtr fusion_node,
                                             const map<FusionRuleAnchorPtr, std::set<ge::AnchorPtr>> &outer_inputs,
                                             const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  const vector<FusionRuleAnchorPtr> &input_anchor = rule_node->GetInputDataAnchors();
  // Start from the flags already stored on the op desc and make sure there is one slot per rule input:
  // the stored vector may be shorter than the anchor list, which would make the write below go out of
  // bounds.
  vector<bool> is_input_const = fusion_node->GetOpDesc()->GetIsInputConst();
  if (is_input_const.size() < input_anchor.size()) {
    is_input_const.resize(input_anchor.size(), false);
  }

  for (auto const &item : input_anchor) {
    const vector<FusionRuleAnchorPtr> &peer_anchors = item->GetPeerAnchors();
    int input_index = item->GetAnchorIdx();
    // peer anchor number must be 1, because the input anchor corresponds to
    // only one out anchor
    if (peer_anchors.size() != 1) {
      // Adjacent literals instead of a backslash continuation, which would embed a tab and spaces in the
      // message.
      REPORT_FE_ERROR(
          "[GraphOpt][RunFusionRule][RplInDataAncr] fusion node[%s] input anchor corresponds output "
          "anchor number not 1",
          fusion_node->GetName().c_str());
      return FAILED;
    }

    const FusionRuleAnchorPtr &rule_peer_anchor = peer_anchors[0];
    FusionRuleNodePtr peer_rule_node = rule_peer_anchor->GetOwnerNode();
    const bool peer_is_fusion_node = fusion_sub_graph.find(peer_rule_node) != fusion_sub_graph.end();

    // Declared per iteration so that an input never reuses the source found for a previous input.
    ge::NodePtr peer_node = nullptr;
    ge::OutDataAnchorPtr src_anchor = nullptr;
    if (peer_is_fusion_node && outer_inputs.empty()) {
      // if the edge is between fusion node and fusion node, find peer node in
      // fusion sub graph
      peer_node = fusion_sub_graph.at(peer_rule_node);
      if (peer_node != nullptr) {
        src_anchor = peer_node->GetOutDataAnchor(rule_peer_anchor->GetAnchorIdx());
      }
    } else if (!peer_is_fusion_node && !outer_inputs.empty()) {
      // if the edge is between fusion node and outer input node, find peer
      // anchor in pre-fusion sub graph
      auto peer_anchors_pair = outer_inputs.find(rule_peer_anchor);
      if (peer_anchors_pair == outer_inputs.end()) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplInDataAncr] outer input anchor[%s] not in match_result",
                        rule_peer_anchor->GetAnchorName().c_str());
        return FAILED;
      }
      if (!peer_anchors_pair->second.empty()) {
        ge::AnchorPtr peer_anchor = *(peer_anchors_pair->second.begin());
        if (peer_anchor != nullptr) {
          src_anchor = dynamic_pointer_cast<ge::OutDataAnchor>(peer_anchor);
          peer_node = peer_anchor->GetOwnerNode();
        }
      }
    } else {
      continue;
    }

    // Covers an empty outer anchor set as well as null map values / anchors / owner nodes.
    if (peer_node == nullptr) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplInDataAncr] can not find peer node for input[%d] of node[%s]",
                      input_index, fusion_node->GetName().c_str());
      return FAILED;
    }

    ge::InDataAnchorPtr dst_anchor = fusion_node->GetInDataAnchor(input_index);
    // Because the output anchor can correspond to multiple inputs, we can
    // directly add edges
    if (ge::GraphUtils::AddEdge(src_anchor, dst_anchor) == ge::GRAPH_FAILED) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplInDataAncr] add edge from node[%s] to node[%s] error",
                      peer_node->GetName().c_str(), fusion_node->GetName().c_str());
      return FAILED;
    }

    if (peer_node->GetType() == CONSTANT && input_index >= 0) {
      const size_t const_slot = static_cast<size_t>(input_index);
      if (const_slot >= is_input_const.size()) {
        is_input_const.resize(const_slot + 1, false);
      }
      is_input_const[const_slot] = true;
    }
  }
  fusion_node->GetOpDesc()->SetIsInputConst(is_input_const);
  return SUCCESS;
}

// Connects the output control anchor of `fusion_node` to the outer graph.
// Only peers whose rule node is NOT part of `fusion_sub_graph` are handled here; according to the original
// comment, edges between two fusion nodes are linked by ReplaceInputAnchors. For every such peer the graph
// anchors recorded in `outer_outputs` are linked through LinkOuterOutputEdges.
// Returns SUCCESS (also when the rule node has no output control anchor) or FAILED when the peer is
// missing from `outer_outputs` or linking fails.
Status GraphReplace::ReplaceOutputCtrlAnchors(const FusionRuleNodePtr &rule_node, ge::NodePtr fusion_node,
                                              const map<FusionRuleAnchorPtr, set<ge::AnchorPtr>> &outer_outputs,
                                              const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  FusionRuleAnchorPtr ctrl_out_rule_anchor = rule_node->GetOutputCtrlAnchor();
  if (ctrl_out_rule_anchor == nullptr) {
    return SUCCESS;
  }

  const vector<FusionRuleAnchorPtr> &rule_peers = ctrl_out_rule_anchor->GetPeerAnchors();
  ge::OutControlAnchorPtr src_anchor = fusion_node->GetOutControlAnchor();
  for (const FusionRuleAnchorPtr &rule_peer : rule_peers) {
    FusionRuleAnchorPtr peer_anchor = rule_peer;
    FusionRuleNodePtr peer_rule_node = peer_anchor->GetOwnerNode();
    if (fusion_sub_graph.find(peer_rule_node) != fusion_sub_graph.end()) {
      // Peer is another fusion node: nothing to do in this function.
      continue;
    }

    auto recorded = outer_outputs.find(peer_anchor);
    if (recorded == outer_outputs.end()) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplOutCtrlAncr] outer output anchor[%s] not in match result",
                      peer_anchor->GetAnchorName().c_str());
      return FAILED;
    }

    set<ge::AnchorPtr> outer_anchors = recorded->second;
    if (LinkOuterOutputEdges(src_anchor, outer_anchors) == FAILED) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplOutCtrlAncr] link fusion node with outer node failed");
      return FAILED;
    }
  }

  return SUCCESS;
}

// Connects the output data anchors of `fusion_node` to the outer graph.
// For every output data anchor of the rule node, the fusion node's output anchor with the same index is
// the source. Only peers whose rule node is NOT part of `fusion_sub_graph` are handled here; according to
// the original comment, edges between two fusion nodes are linked by ReplaceInputAnchors. For every such
// peer the graph anchors recorded in `outer_outputs` are linked through LinkOuterOutputEdges.
// Returns SUCCESS, or FAILED when a peer is missing from `outer_outputs` or linking fails.
Status GraphReplace::ReplaceOutputDataAnchors(const FusionRuleNodePtr &rule_node, ge::NodePtr fusion_node,
                                              const map<FusionRuleAnchorPtr, set<ge::AnchorPtr>> &outer_outputs,
                                              const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  const vector<FusionRuleAnchorPtr> &rule_outputs = rule_node->GetOutputDataAnchors();
  for (const FusionRuleAnchorPtr &rule_output : rule_outputs) {
    const vector<FusionRuleAnchorPtr> &rule_peers = rule_output->GetPeerAnchors();
    int output_index = rule_output->GetAnchorIdx();
    ge::OutDataAnchorPtr src_anchor = fusion_node->GetOutDataAnchor(output_index);

    for (const FusionRuleAnchorPtr &rule_peer : rule_peers) {
      FusionRuleAnchorPtr peer_anchor = rule_peer;
      FusionRuleNodePtr peer_rule_node = peer_anchor->GetOwnerNode();
      if (fusion_sub_graph.find(peer_rule_node) != fusion_sub_graph.end()) {
        // Peer is another fusion node: nothing to do in this function.
        continue;
      }

      auto recorded = outer_outputs.find(peer_anchor);
      if (recorded == outer_outputs.end()) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplOutDataAncr] outer output anchor[%s] not in match result",
                        peer_anchor->GetAnchorName().c_str());
        return FAILED;
      }

      set<ge::AnchorPtr> outer_anchors = recorded->second;
      if (LinkOuterOutputEdges(src_anchor, outer_anchors) == FAILED) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][RplOutDataAncr] link fusion node with outer node error");
        return FAILED;
      }
    }
  }
  return SUCCESS;
}

// Links all outputs of `fusion_node` to the origin graph: data outputs first,
// control output second. Stops and returns FAILED as soon as one of the two
// steps reports FAILED; returns SUCCESS otherwise.
Status GraphReplace::ReplaceOutputAnchors(const FusionRuleNodePtr &rule_node, ge::NodePtr fusion_node,
                                          const map<FusionRuleAnchorPtr, set<ge::AnchorPtr>> &outer_outputs,
                                          const map<FusionRuleNodePtr, ge::NodePtr> &fusion_sub_graph) {
  if (ReplaceOutputDataAnchors(rule_node, fusion_node, outer_outputs, fusion_sub_graph) == FAILED) {
    return FAILED;
  }
  if (ReplaceOutputCtrlAnchors(rule_node, fusion_node, outer_outputs, fusion_sub_graph) == FAILED) {
    return FAILED;
  }
  return SUCCESS;
}

// Adds an edge from `src_anchor` to every anchor in `outer_anchors`.
//
// src_anchor    : output anchor (data or control) of a fusion node.
// outer_anchors : input anchors of origin graph nodes; each may be a control
//                 input anchor or a data input anchor.
//
// Returns SUCCESS when every edge was added (including the case of an empty
// set), FAILED on a null anchor or as soon as one edge cannot be added.
Status GraphReplace::LinkOuterOutputEdges(ge::AnchorPtr src_anchor, const set<ge::AnchorPtr> &outer_anchors) {
  for (ge::AnchorPtr const &outer_anchor : outer_anchors) {
    // Reject null anchors before they reach the error log below, which reads
    // the owner nodes of both anchors. The check lives inside the loop so an
    // empty set still succeeds regardless of src_anchor.
    if (src_anchor == nullptr || outer_anchor == nullptr) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][LkOutOutEdge] source anchor or outer anchor is null");
      return FAILED;
    }
    // outer_anchor maybe control input anchor or data input anchor
    if (ge::GraphUtils::AddEdge(src_anchor, outer_anchor) == ge::GRAPH_FAILED) {
      const ge::NodePtr src_node = src_anchor->GetOwnerNode();
      const ge::NodePtr dst_node = outer_anchor->GetOwnerNode();
      // an anchor's owner node may be unavailable; fall back to a placeholder
      const string src_name = (src_node != nullptr) ? src_node->GetName() : string("null");
      const string dst_name = (dst_node != nullptr) ? dst_node->GetName() : string("null");
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][LkOutOutEdge] add data edge from node[%s] to node[%s] error",
                      src_name.c_str(), dst_name.c_str());
      return FAILED;
    }
  }
  return SUCCESS;
}

// Validates the freshly created fusion nodes before they replace the matched
// sub graph. The steps run in this order and the first failure aborts:
//   1. link the edges between the fusion nodes (LinkFusionNode);
//   2. topologically sort the fusion nodes (TopoSortFusionNode);
//   3. link the fusion nodes to their outer inputs (LinkOuterInputsEdge);
//   4. infer shape, data type and origin format (InferShapeDtypeAndFormat);
//   5. check that every fusion node is supported (CheckSupported);
//   6. check that output shape/data type match the outer consumers
//      (CheckShapeAndTypeContinuous).
//
// match_result        : match result; its outer_inputs_set and outer_outputs
//                       are read.
// fusion_rule_pattern : not used by this function.
// fusion_nodes        : rule node -> graph node mapping of the fusion nodes.
//
// Returns SUCCESS when every step passes, FAILED otherwise.
Status GraphReplace::CheckFusionNode(GraphMatchResult &match_result, const FusionRulePattern &fusion_rule_pattern,
                                     const map<FusionRuleNodePtr, ge::NodePtr> &fusion_nodes) {
  vector<ge::NodePtr> sort_nodes;
  if (!LinkFusionNode(fusion_nodes)) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkFusNd] link the edges between the fusion nodes failed");
    return FAILED;
  }
  // NOTE(review): the sort runs before the outer inputs are linked,
  // presumably so that only edges among fusion nodes are counted - confirm
  // against ReplaceInputAnchors.
  if (!TopoSortFusionNode(fusion_nodes, sort_nodes)) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkFusNd] sort fusion node failed");
    return FAILED;
  }
  // reverse mapping (graph node -> rule node) used by the checks below
  map<ge::NodePtr, FusionRuleNodePtr> search_nodes;
  for (auto const &iter : fusion_nodes) {
    search_nodes[iter.second] = iter.first;
  }

  if (!LinkOuterInputsEdge(fusion_nodes, match_result.outer_inputs_set)) {
    REPORT_FE_ERROR(
        "[GraphOpt][RunFusionRule][ChkFusNd] link the edges between the fusion nodes and outer nodes failed");
    return FAILED;
  }
  if (!InferShapeDtypeAndFormat(sort_nodes, search_nodes, match_result.outer_outputs)) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkFusNd] infer shape, data type or origin format failed");
    return FAILED;
  }
  if (!CheckSupported(sort_nodes)) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkFusNd] check shape and data type support failed");
    return FAILED;
  }
  if (!CheckShapeAndTypeContinuous(sort_nodes, search_nodes, match_result.outer_outputs)) {
    // distinct message so this failure is not confused with CheckSupported
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkFusNd] check shape and data type continuity failed");
    return FAILED;
  }
  return SUCCESS;
}

bool GraphReplace::LinkFusionNode(const map<FusionRuleNodePtr, ge::NodePtr> &fusion_nodes) {
  for (auto &item : fusion_nodes) {
    FusionRuleNodePtr rule_node = item.first;
    ge::NodePtr fusion_node = item.second;
    FE_LOGD("link fused rule node:%s, graph node:%s", rule_node->GetNodeName().c_str(), fusion_node->GetName().c_str());
    if (ReplaceInputAnchors(rule_node, fusion_node, {}, fusion_nodes) == FAILED) {
      return false;
    }
  }
  return true;
}

bool GraphReplace::TopoSortFusionNode(const map<FusionRuleNodePtr, ge::NodePtr> &fusion_nodes,
                                      vector<ge::NodePtr> &sort_nodes) {
  stack<ge::NodePtr> node_stack;
  map<ge::NodePtr, int> node_inputs_map;
  // find the node without input, and compute the input numbers of other nodes
  for (auto const &item : fusion_nodes) {
    ge::NodePtr node = item.second;
    int input_size = 0;
    for (ge::InDataAnchorPtr const &anchor : node->GetAllInDataAnchors()) {
      if (anchor != nullptr) {
        if (CheckInt32AddOverflow(input_size, anchor->GetPeerAnchors().size()) == FAILED) {
          REPORT_FE_ERROR("[GraphOpt][RunFusionRule][TpsrFusNd] fusion node:%s peer anchor size is too much.",
                          node->GetName().c_str());
          return false;
        }
        input_size += anchor->GetPeerAnchors().size();
      }
    }
    if (input_size == 0) {
      node_stack.push(node);
    } else {
      node_inputs_map[node] = input_size;
    }
  }
  while (!node_stack.empty()) {
    ge::NodePtr node = node_stack.top();
    node_stack.pop();
    sort_nodes.push_back(node);
    FE_LOGD("sort nodes push back node[%s]", node->GetName().c_str());
    // the value of node_input_map is node's input number,
    // when the value is zero, representing the parent nodes of this node
    // have been visited and this node can be sorted
    for (ge::OutDataAnchorPtr const &anchor : node->GetAllOutDataAnchors()) {
      for (ge::InDataAnchorPtr const &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
        auto iter = node_inputs_map.find(peer_in_anchor->GetOwnerNode());
        if (iter != node_inputs_map.end() && --(iter->second) == 0) {
          node_stack.push(peer_in_anchor->GetOwnerNode());
        }
      }
    }
  }
  if (sort_nodes.size() != fusion_nodes.size()) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][TpsrFusNd] sort nodes size not equal with fusion nodes");
    return false;
  }
  return true;
}

bool GraphReplace::LinkOuterInputsEdge(const map<FusionRuleNodePtr, ge::NodePtr> &fusion_nodes,
                                       const map<FusionRuleAnchorPtr, std::set<ge::AnchorPtr>> &outer_inputs) {
  for (auto const &item : fusion_nodes) {
    FusionRuleNodePtr rule_node = item.first;
    ge::NodePtr fusion_node = item.second;
    FE_LOGD("linkouter fused rule node:%s, graph node:%s", rule_node->GetNodeName().c_str(),
            fusion_node->GetName().c_str());
    if (ReplaceInputAnchors(rule_node, fusion_node, outer_inputs, fusion_nodes) == FAILED) {
      return false;
    }
  }
  return true;
}

// Copies each output desc of `node` into the matching input desc of every
// consumer that is itself a fusion node.
//
// node         : node whose output descs are propagated.
// fusion_nodes : graph node -> rule node mapping; only consumers present in
//                this map are updated.
//
// An output whose desc cannot be obtained is skipped with a warning.
void UpdateInputDescForPeerNode(const ge::NodePtr &node, const map<ge::NodePtr, FusionRuleNodePtr> &fusion_nodes) {
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      auto peer_node = peer_in_anchor->GetOwnerNode();
      if (fusion_nodes.find(peer_node) == fusion_nodes.end()) {
        // the consumer is not a fusion node, leave its input desc untouched
        continue;
      }
      auto output_tensor = node->GetOpDesc()->GetOutputDescPtr(out_anchor->GetIdx());
      if (output_tensor == nullptr) {
        // guard the dereferences below; CheckShape/CheckDataType treat a null
        // desc pointer the same way
        FE_LOGW("the output desc [%d] of the node [%s] is null, skip updating the input desc of node [%s].",
                out_anchor->GetIdx(), node->GetName().c_str(), peer_node->GetName().c_str());
        continue;
      }
      FE_LOGD("the output desc of the node [%s]: format[%u], origin_format[%u], dtype[%u], shape[%s].",
              node->GetName().c_str(), output_tensor->GetFormat(), output_tensor->GetOriginFormat(),
              output_tensor->GetDataType(), GetShapeDims(output_tensor->GetShape()).c_str());
      auto peer_opdesc = peer_node->GetOpDesc();
      peer_opdesc->UpdateInputDesc(peer_in_anchor->GetIdx(), *output_tensor);
    }
  }
}

// Runs InferShapeAndType and then InferOriginFormat on `node`.
// Returns false (after reporting an error) when either call does not return
// ge::SUCCESS, true when both succeed.
bool InferShape(const ge::NodePtr &node) {
  FE_LOGI("node %s: start to InferShapeAndType.", node->GetName().c_str());
  const auto shape_ret = node->InferShapeAndType();
  if (shape_ret != ge::SUCCESS) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InferShape] node[%s] InferShapeAndType failed", node->GetName().c_str());
    return false;
  }

  FE_LOGI("node %s: start to InferOriginFormat.", node->GetName().c_str());
  const auto format_ret = node->InferOriginFormat();
  if (format_ret != ge::SUCCESS) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InferShape] node[%s] InferOriginFormat failed", node->GetName().c_str());
    return false;
  }

  return true;
}

// Fills in the tensor descs of the fusion nodes, visiting them in the given
// (topological) order.
//
// sort_nodes    : fusion nodes in topological order.
// fusion_nodes  : graph node -> rule node mapping of the fusion nodes.
// outer_outputs : matched graph anchors for the rule anchors outside the
//                 fusion sub graph.
//
// For every node:
//   1. inputs fed by an outer node take over that node's output desc;
//   2. if there is exactly one fusion node, its output descs are copied from
//      the input descs of its outer consumers and the function returns;
//   3. otherwise shape, data type and origin format are inferred;
//   4. the inferred output descs are pushed to consumers that are fusion
//      nodes.
//
// Returns false on an unconnected input, a missing tensor desc, a missing
// mapping or a failed update/inference; true otherwise.
bool GraphReplace::InferShapeDtypeAndFormat(const vector<ge::NodePtr> &sort_nodes,
                                            const map<ge::NodePtr, FusionRuleNodePtr> &fusion_nodes,
                                            const map<FusionRuleAnchorPtr, set<ge::AnchorPtr>> &outer_outputs) {
  for (auto const &node : sort_nodes) {
    // 1. if the input of fusion node is outer node, update output tensor to
    // input opdesc
    for (ge::InDataAnchorPtr const &anchor : node->GetAllInDataAnchors()) {
      ge::OutDataAnchorPtr peer_anchor = anchor->GetPeerOutAnchor();
      if (peer_anchor == nullptr) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InfShpDtpFmt] node[%s] peer anchor [%d] is null",
                        node->GetName().c_str(), anchor->GetIdx());
        return false;
      }
      ge::NodePtr peer_node = peer_anchor->GetOwnerNode();
      if (fusion_nodes.find(peer_node) != fusion_nodes.end()) {
        // producer is a fusion node, its desc is propagated in step 4
        continue;
      }
      auto output_tensor = peer_node->GetOpDesc()->GetOutputDescPtr(peer_anchor->GetIdx());
      if (output_tensor == nullptr) {
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InfShpDtpFmt] node[%s] output desc [%d] is null",
                        peer_node->GetName().c_str(), peer_anchor->GetIdx());
        return false;
      }
      node->GetOpDesc()->UpdateInputDesc(anchor->GetIdx(), *output_tensor);
    }

    // 2. if there is only one fusion node update output tensor with its child
    // nodes input tensor
    if (sort_nodes.size() == 1) {
      const auto rule_iter = fusion_nodes.find(node);
      if (rule_iter == fusion_nodes.end()) {
        // find() instead of at(): report the problem instead of throwing
        REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InfShpDtpFmt] node[%s] is not found in fusion nodes",
                        node->GetName().c_str());
        return false;
      }
      const vector<FusionRuleAnchorPtr> &output_anchor = rule_iter->second->GetOutputDataAnchors();
      for (const auto &out_anchor : output_anchor) {
        int output_index = out_anchor->GetAnchorIdx();
        for (const auto &peer_rule_input_anchor : out_anchor->GetPeerAnchors()) {
          // single lookup, and no copy of the anchor set
          const auto outer_iter = outer_outputs.find(peer_rule_input_anchor);
          if (outer_iter == outer_outputs.end()) {
            REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InfShpDtpFmt] node[%s] does not have output anchor",
                            node->GetName().c_str());
            return false;
          }
          for (auto const &input_anchor : outer_iter->second) {
            ge::InDataAnchorPtr dst_data_anchor = ge::Anchor::DynamicAnchorCast<ge::InDataAnchor>(input_anchor);
            if (dst_data_anchor == nullptr) {
              // not a data input anchor, nothing to copy from
              continue;
            }
            auto peer_node = dst_data_anchor->GetOwnerNode();
            auto input_tensor = peer_node->GetOpDesc()->GetInputDescPtr(dst_data_anchor->GetIdx());
            if (input_tensor == nullptr) {
              REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InfShpDtpFmt] node[%s] input desc [%d] is null",
                              peer_node->GetName().c_str(), dst_data_anchor->GetIdx());
              return false;
            }
            if (node->GetOpDesc()->UpdateOutputDesc(output_index, *input_tensor) == ge::GRAPH_FAILED) {
              REPORT_FE_ERROR("[GraphOpt][RunFusionRule][InfShpDtpFmt] node[%s] update output desc failed",
                              node->GetName().c_str());
              return false;
            }
            // the first data consumer is enough for this peer rule anchor
            break;
          }
        }
      }
      return true;
    }

    // 3. if there are many fusion nodes: InferShapeAndType and InferFormat for
    // the output desc of the fusion node
    if (!InferShape(node)) {
      return false;
    }

    // 4. update the input desc for the peer node
    UpdateInputDescForPeerNode(node, fusion_nodes);
  }
  return true;
}

bool GraphReplace::CheckSupported(const vector<ge::NodePtr> &sort_nodes) {
  for (auto const &node : sort_nodes) {
    auto opdesc = node->GetOpDesc();
    if (IsPlaceOrEnd(opdesc->GetType())) {
      continue;
    }
    string un_supported_reason;
    if (ops_kernel_info_store_ptr_->CheckSupported(opdesc, un_supported_reason) == false) {
      REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkSpt] Node[%s] check_shape_and_type failed, reason is %s",
                      node->GetName().c_str(), un_supported_reason.c_str());
      return false;
    }
  }
  return true;
}

bool GraphReplace::CheckShapeAndTypeContinuous(const vector<ge::NodePtr> &sort_nodes,
                                               const map<ge::NodePtr, FusionRuleNodePtr> &fusion_nodes,
                                               const map<FusionRuleAnchorPtr, set<ge::AnchorPtr>> &outer_outputs) {
  // if there is only one fusion node, the input tensor is the same as parent
  // node,
  // the output tensor is the same as child node
  if (sort_nodes.size() == 1) {
    return true;
  }
  for (auto const &node : sort_nodes) {
    auto rule_node = fusion_nodes.at(node);
    const vector<FusionRuleAnchorPtr> &output_anchor = rule_node->GetOutputDataAnchors();
    for (const auto &out_anchor : output_anchor) {
      const vector<FusionRuleAnchorPtr> &peer_rule_anchors = out_anchor->GetPeerAnchors();
      int output_index = out_anchor->GetAnchorIdx();
      auto output_data_anchor = node->GetOutDataAnchor(output_index);
      for (const auto &peer_rule_input_anchor : peer_rule_anchors) {
        if (outer_outputs.find(peer_rule_input_anchor) == outer_outputs.end()) {
          continue;
        }
        auto input_anchors = outer_outputs.at(peer_rule_input_anchor);
        for (auto const &input_anchor : input_anchors) {
          ge::InDataAnchorPtr dst_data_anchor = ge::Anchor::DynamicAnchorCast<ge::InDataAnchor>(input_anchor);
          // control anchor
          if (dst_data_anchor == nullptr) {
            continue;
          }
          // fusion node output shape and data type should be same with peer
          // outer node input shape and data type
          if (!CheckDataType(output_data_anchor, dst_data_anchor)) {
            REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkShpTypeContus] node[%s] check_data_type failed",
                            node->GetName().c_str());
            return false;
          }

          if (!CheckShape(output_data_anchor, dst_data_anchor)) {
            REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkShpTypeContus] node[%s] check_shape failed",
                            node->GetName().c_str());
            return false;
          }
        }
      }
    }
  }
  return true;
}
// Compares the dims of the output desc behind `out_anchor` with the dims of
// the input desc behind `peer_in_anchor`.
// Returns false when either desc is null (reported as an error) or when the
// dims differ (logged as a warning with both shapes); true when they match.
bool GraphReplace::CheckShape(ge::OutDataAnchorPtr out_anchor, ge::InDataAnchorPtr peer_in_anchor) {
  auto consumer_node = peer_in_anchor->GetOwnerNode();
  auto producer_node = out_anchor->GetOwnerNode();
  auto producer_desc = producer_node->GetOpDesc();
  auto consumer_desc = consumer_node->GetOpDesc();
  ge::ConstGeTensorDescPtr out_tensor = producer_desc->GetOutputDescPtr(out_anchor->GetIdx());
  ge::ConstGeTensorDescPtr in_tensor = consumer_desc->GetInputDescPtr(peer_in_anchor->GetIdx());
  if (out_tensor == nullptr || in_tensor == nullptr) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkShp] node [%s] output or peer input is null",
                    producer_desc->GetName().c_str());
    return false;
  }

  const auto out_dims = out_tensor->GetShape().GetDims();
  const auto in_dims = in_tensor->GetShape().GetDims();
  if (out_dims == in_dims) {
    return true;
  }

  // Mismatch: render both shapes as "d0,d1,...," for the warning.
  std::ostringstream out_dims_stream;
  for (const auto &dim : out_dims) {
    out_dims_stream << dim << ',';
  }
  std::ostringstream in_dims_stream;
  for (const auto &dim : in_dims) {
    in_dims_stream << dim << ',';
  }
  FE_LOGW("node [%s] output[%d] shape[%s] should be equal with node[%s] input [%d] shape [%s]",
          producer_node->GetName().c_str(), out_anchor->GetIdx(), out_dims_stream.str().c_str(),
          consumer_node->GetName().c_str(), peer_in_anchor->GetIdx(), in_dims_stream.str().c_str());
  return false;
}
// Compares the data type of the output desc behind `out_anchor` with the data
// type of the input desc behind `peer_in_anchor`.
// Returns false when either desc is null (reported as an error) or when the
// data types differ (logged as a warning); true when they match.
bool GraphReplace::CheckDataType(ge::OutDataAnchorPtr out_anchor, ge::InDataAnchorPtr peer_in_anchor) {
  auto consumer_node = peer_in_anchor->GetOwnerNode();
  auto producer_node = out_anchor->GetOwnerNode();
  auto producer_desc = producer_node->GetOpDesc();
  auto consumer_desc = consumer_node->GetOpDesc();
  ge::ConstGeTensorDescPtr out_tensor = producer_desc->GetOutputDescPtr(out_anchor->GetIdx());
  ge::ConstGeTensorDescPtr in_tensor = consumer_desc->GetInputDescPtr(peer_in_anchor->GetIdx());
  if (out_tensor == nullptr || in_tensor == nullptr) {
    REPORT_FE_ERROR("[GraphOpt][RunFusionRule][ChkDatatype] node [%s] output or peer input is null",
                    producer_desc->GetName().c_str());
    return false;
  }

  const ge::DataType out_dtype = out_tensor->GetDataType();
  const ge::DataType in_dtype = in_tensor->GetDataType();
  if (out_dtype == in_dtype) {
    return true;
  }

  FE_LOGW("node [%s] output [%d] data type [%d] should be equal with node [%s] input [%d] data type [%d]",
          producer_node->GetName().c_str(), out_anchor->GetIdx(), out_dtype, consumer_node->GetName().c_str(),
          peer_in_anchor->GetIdx(), in_dtype);
  return false;
}
}  // namespace fe
